{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "a3863617-6511-408a-9e21-4561a9c520dd",
   "metadata": {},
   "source": [
    "# Prompt+LLM\n",
    "\n",
    "基本构成：\n",
    "\n",
    "`PromptTemplate / ChatPromptTemplate` -> `LLM / ChatModel` -> `OutputParser`"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bcca77b3-ff1e-4011-b938-0d4895aa498e",
   "metadata": {},
   "outputs": [],
   "source": [
    "from langchain_core.prompts import ChatPromptTemplate\n",
    "from langchain_openai import ChatOpenAI\n",
    "\n",
    "# Chat prompt with a single placeholder, {foo}, filled in when the chain is invoked.\n",
    "prompt = ChatPromptTemplate.from_template(\"给我讲一个关于{foo}的笑话\")\n",
    "model = ChatOpenAI(model=\"gpt-3.5-turbo\", temperature=0)\n",
    "\n",
    "# Pipe the formatted prompt straight into the chat model.\n",
    "chain = prompt | model\n",
    "\n",
    "chain.invoke({\"foo\": \"狗熊\"})"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1b19b860-79be-4b2a-8f61-bc913d81eb50",
   "metadata": {},
   "source": [
    "## 自定义停止输出符"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e565c568-de2a-4a09-9e65-f56d3d5153a3",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Bind a stop sequence: generation ends as soon as the model emits a newline.\n",
    "chain = prompt | model.bind(stop=[\"\\n\"])\n",
    "\n",
    "# Run the chain so the effect of the stop sequence is visible in the cell output.\n",
    "chain.invoke({\"foo\": \"狗熊\"})"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7db05b28-4ce4-4069-bf99-988e8bcbd605",
   "metadata": {},
   "source": [
    "## 兼容 OpenAI 函数调用的方式"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9c393d15-3a75-4a63-9380-58def37df674",
   "metadata": {},
   "outputs": [],
   "source": [
    "# OpenAI function-calling schema: a `joke` function with two required string fields.\n",
    "functions = [\n",
    "    {\n",
    "        \"name\": \"joke\",\n",
    "        \"description\": \"讲笑话\",\n",
    "        \"parameters\": {\n",
    "            \"type\": \"object\",\n",
    "            \"properties\": {\n",
    "                \"setup\": {\"type\": \"string\", \"description\": \"笑话的开头\"},\n",
    "                \"punchline\": {\n",
    "                    \"type\": \"string\",\n",
    "                    \"description\": \"爆梗的结尾\",\n",
    "                },\n",
    "            },\n",
    "            \"required\": [\"setup\", \"punchline\"],\n",
    "        },\n",
    "    }\n",
    "]\n",
    "# Bind the schema to the model; function_call={\"name\": \"joke\"} requests the `joke` function by name.\n",
    "chain = prompt | model.bind(function_call={\"name\": \"joke\"}, functions=functions)\n",
    "\n",
    "chain.invoke({\"foo\": \"男人\"})"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "cc2cd7a6-4735-4ff4-ab66-6f9e067c5955",
   "metadata": {},
   "source": [
    "## 输出解析器"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "33c0fc8f-23aa-414c-bac5-5a211fe386ca",
   "metadata": {},
   "outputs": [],
   "source": [
    "from langchain_core.output_parsers import StrOutputParser\n",
    "\n",
    "# Append a StrOutputParser so the chain returns a plain string instead of a message object.\n",
    "chain = (\n",
    "    prompt\n",
    "    | model\n",
    "    | StrOutputParser()\n",
    ")\n",
    "\n",
    "chain.invoke({\"foo\": \"女人\"})"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0745e34d-b7a3-4568-bd3e-2ba948f58b63",
   "metadata": {},
   "source": [
    "## 与函数调用混合使用"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1beec247-82a3-43db-9c65-23c97915af95",
   "metadata": {},
   "outputs": [],
   "source": [
    "from langchain_core.output_parsers.openai_functions import JsonOutputFunctionsParser\n",
    "\n",
    "# Parse the arguments of the `joke` function call into a Python dict.\n",
    "chain = (\n",
    "    prompt\n",
    "    | model.bind(function_call={\"name\": \"joke\"}, functions=functions)\n",
    "    | JsonOutputFunctionsParser()\n",
    ")\n",
    "\n",
    "chain.invoke({\"foo\": \"女人\"})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7e4be91c-abdf-4095-a819-c71eb9461a4e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Return only the `punchline` field of the function-call arguments.\n",
    "from langchain_core.output_parsers.openai_functions import JsonKeyOutputFunctionsParser\n",
    "from langchain_core.runnables import RunnableParallel, RunnablePassthrough\n",
    "\n",
    "chain = (\n",
    "    # RunnablePassthrough() forwards the raw invoke() input unchanged as `foo`,\n",
    "    # so the chain can be called with a bare string instead of a dict.\n",
    "    {\"foo\": RunnablePassthrough()}\n",
    "    | prompt\n",
    "    | model.bind(function_call={\"name\": \"joke\"}, functions=functions)\n",
    "    | JsonKeyOutputFunctionsParser(key_name=\"punchline\")  # key to extract from the parsed arguments\n",
    ")\n",
    "\n",
    "# Invoke with a bare string to show the passthrough mapping and the extracted key.\n",
    "chain.invoke(\"女人\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3d0e48a1-26c9-4d10-8434-6a1415bbbf8d",
   "metadata": {},
   "source": [
    "## 使用Runnables来连接多链结构"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5d1c4a1f-988e-4ebe-8270-1c4d1324ec2b",
   "metadata": {},
   "outputs": [],
   "source": [
    "from operator import itemgetter  # builds a callable that fetches a given key/index from its argument\n",
    "\n",
    "from langchain_core.output_parsers import StrOutputParser\n",
    "from langchain_core.prompts import ChatPromptTemplate\n",
    "from langchain_openai import ChatOpenAI\n",
    "\n",
    "prompt1 = ChatPromptTemplate.from_template(\"{person}来自于哪个城市?\")\n",
    "prompt2 = ChatPromptTemplate.from_template(\n",
    "    \"{city}属于哪个省? 用{language}来回答\"\n",
    ")\n",
    "\n",
    "model = ChatOpenAI()\n",
    "\n",
    "# First chain: person -> city, returned as a plain string.\n",
    "chain1 = prompt1 | model | StrOutputParser()\n",
    "\n",
    "chain2 = (\n",
    "    # `city` comes from running chain1 on the same input;\n",
    "    # `language` is read from the dict passed to invoke().\n",
    "    {\"city\": chain1, \"language\": itemgetter(\"language\")}\n",
    "    | prompt2\n",
    "    | model\n",
    "    | StrOutputParser()\n",
    ")\n",
    "\n",
    "# Only the last expression of a cell is displayed, so print the intermediate answer explicitly.\n",
    "print(chain1.invoke({\"person\": \"马化腾\"}))\n",
    "chain2.invoke({\"person\": \"马化腾\", \"language\": \"中文\"})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cc1770f9-04be-470c-a5ea-67cedbc0f6a8",
   "metadata": {},
   "outputs": [],
   "source": [
    "from langchain_core.runnables import RunnablePassthrough\n",
    "\n",
    "prompt1 = ChatPromptTemplate.from_template(\n",
    "    \"生成一个{attribute}属性的颜色。除了返回这个颜色的名字不要做其他事:\"\n",
    ")\n",
    "prompt2 = ChatPromptTemplate.from_template(\n",
    "    \"什么水果是这个颜色:{color},只返回这个水果的名字不要做其他事情:\"\n",
    ")\n",
    "prompt3 = ChatPromptTemplate.from_template(\n",
    "    \"哪个国家的国旗有这个颜色:{color},只返回这个国家的名字不要做其他事情:\"\n",
    ")\n",
    "prompt4 = ChatPromptTemplate.from_template(\n",
    "    \"有这个颜色的水果是{fruit},有这个颜色的国旗是{country}？\"\n",
    ")\n",
    "\n",
    "# Shared tail: call the model and reduce its reply to a plain string.\n",
    "model_parser = model | StrOutputParser()\n",
    "\n",
    "# Step 1: the raw input fills {attribute}; the model's reply is wrapped under the key \"color\".\n",
    "color_generator = (\n",
    "    {\"attribute\": RunnablePassthrough()}\n",
    "    | prompt1\n",
    "    | {\"color\": model_parser}\n",
    ")\n",
    "\n",
    "# Step 2: two branches that each consume the generated color.\n",
    "color_to_fruit = prompt2 | model_parser\n",
    "color_to_country = prompt3 | model_parser\n",
    "\n",
    "# Step 3: run both branches and feed their answers into prompt4.\n",
    "question_generator = (\n",
    "    color_generator\n",
    "    | {\"fruit\": color_to_fruit, \"country\": color_to_country}\n",
    "    | prompt4\n",
    ")\n",
    "\n",
    "question_generator.invoke(\"强烈的\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python (langchain)",
   "language": "python",
   "name": "langchain-env"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.13.0"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
